{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Test the neural network policy in the training environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING: Method definition info(Any...) in module Base at util.jl:532 overwritten"
     ]
    }
   ],
   "source": [
    "using POMDPs, StatsBase, POMDPToolbox, QMDP, RLInterface, AutomotiveDrivingModels, AutoViz, SARSOP, Images, PyCall, Reel\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " in module Logging at C:\\Users\\Maxime\\.julia\\v0.6\\Logging\\src\\Logging.jl:115.\n",
      "WARNING: Method definition warn(Any...) in module Base at util.jl:585 overwritten in module Logging at C:\\Users\\Maxime\\.julia\\v0.6\\Logging\\src\\Logging.jl:115.\n",
      "18-Oct 11:59:30:WARNING:root:replacing docs for 'scene_to_states :: Tuple{Records.Frame{Records.Entity{AutomotiveDrivingModels.VehicleState,AutomotiveDrivingModels.VehicleDef,Int64}},OCPOMDP}' in module 'Main'.\n",
      "18-Oct 11:59:30:WARNING:root:replacing docs for 'states_to_scene :: Tuple{Dict{Int64,OCState},OCPOMDP}' in module 'Main'.\n"
     ]
    }
   ],
   "source": [
    "include(\"occluded_crosswalk_env.jl\")\n",
    "include(\"pomdp_types.jl\")\n",
    "include(\"spaces.jl\")\n",
    "include(\"transition.jl\")\n",
    "include(\"observation.jl\")\n",
    "include(\"belief.jl\")\n",
    "include(\"decomposition.jl\")\n",
    "include(\"adm_helpers.jl\")\n",
    "include(\"render_helpers.jl\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "function POMDPs.generate_o(pomdp::OCPOMDP, s::OCState, rng::AbstractRNG)\n",
    "    o = generate_o(pomdp, s, OCAction(0.), s, rng)\n",
    "    return o\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "preprocess_o (generic function with 2 methods)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "function preprocess_o(o::OCObs, pomdp::OCPOMDP, queue=nothing)\n",
    "    nframes = 10\n",
    "    o_mat = convert_o(Vector{Float64}, o, pomdp)\n",
    "    o_mat = reshape(o_mat, (size(o_mat)...))\n",
    "    if queue == nothing\n",
    "        w, h, nch = size(o_mat)\n",
    "        o_stacked = repeat(reshape(o_mat, (w, h, nch)), outer=(1,1,nframes))\n",
    "        o_stacked = reshape(o_stacked, (w, h, nch*nframes))\n",
    "        return o_stacked\n",
    "    else\n",
    "        queue = efficient_dequeue(queue, o_mat)\n",
    "    end\n",
    "    return queue\n",
    "end\n",
    "        \n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "efficient_dequeue"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\"\"\"\n",
    "    given a queue of fix length, enqueu new_elem and dequeue the oldest element\n",
    "    without allocating new memory\n",
    "\"\"\"\n",
    "# function efficient_dequeue!(queue::Array{Float64,3}, new_elem::Array{Float64,3}, nch::Int64=size(new_elem,3),nqueue::Int64=div(size(queue, 3),size(new_elem, 3)))\n",
    "#     # first shift all the old element\n",
    "#     for i=1:nqueue-1\n",
    "#         queue[:,:,nch*(i-1)+1:nch*i] = queue[:,:,nch*i:nch*(i+1)-1]\n",
    "#     end\n",
    "#     # enqueue the last one\n",
    "#     queue[:,:,nch*(nqueue-1)+1:nch*nqueue] = new_elem\n",
    "#     return queue\n",
    "# end\n",
    "\n",
    "function efficient_dequeue(queue::Array{Float64,3}, new_elem::Array{Float64,3}, nch::Int64=size(new_elem,3),nqueue::Int64=div(size(queue, 3),size(new_elem, 3)))\n",
    "    # first shift all the old element\n",
    "    queue = circshift(queue, (0,0,-nch))\n",
    "    # enqueue the first one\n",
    "    queue[:,:,(nqueue-1)*nch+1:nqueue*nch] = new_elem\n",
    "    return queue\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MersenneTwister(UInt32[0x00000001], Base.dSFMT.DSFMT_state(Int32[1749029653, 1072851681, 1610647787, 1072862326, 1841712345, 1073426746, -198061126, 1073322060, -156153802, 1073567984  …  1977574422, 1073209915, 278919868, 1072835605, 1290372147, 18858467, 1815133874, -1716870370, 382, 0]), [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 382)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pomdp = OCPOMDP()\n",
    "pomdp.p_birth = 0.3\n",
    "pomdp.pos_res = 0.5\n",
    "pomdp.vel_res = 0.5\n",
    "pomdp.pos_obs_noise = 0.5\n",
    "pomdp.vel_obs_noise = 0.5\n",
    "rng = MersenneTwister(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "PyObject <module 'dqn.nn_wrapper' from 'C:\\\\Users\\\\Maxime\\\\OneDrive - Leland Stanford Junior University\\\\Research\\\\policy-correction\\\\dqn\\\\nn_wrapper.py'>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "@pyimport tensorflow as tf\n",
    "nn_wrapper = pyimport(\"dqn.nn_wrapper\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from ../../dqn/test16/model.ckpt\r\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2017-10-18 11:59:43.536158: W c:\\tf_jenkins\\home\\workspace\\release-win\\m\\windows\\py\\35\\tensorflow\\core\\platform\\cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE instructions, but these are available on your machine and could speed up CPU computations.\n",
      "2017-10-18 11:59:43.536196: W c:\\tf_jenkins\\home\\workspace\\release-win\\m\\windows\\py\\35\\tensorflow\\core\\platform\\cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE2 instructions, but these are available on your machine and could speed up CPU computations.\n",
      "2017-10-18 11:59:43.536203: W c:\\tf_jenkins\\home\\workspace\\release-win\\m\\windows\\py\\35\\tensorflow\\core\\platform\\cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE3 instructions, but these are available on your machine and could speed up CPU computations.\n",
      "2017-10-18 11:59:43.536209: W c:\\tf_jenkins\\home\\workspace\\release-win\\m\\windows\\py\\35\\tensorflow\\core\\platform\\cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE4.1 instructions, but these are available on your machine and could speed up CPU computations.\n",
      "2017-10-18 11:59:43.536214: W c:\\tf_jenkins\\home\\workspace\\release-win\\m\\windows\\py\\35\\tensorflow\\core\\platform\\cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE4.2 instructions, but these are available on your machine and could speed up CPU computations.\n",
      "2017-10-18 11:59:43.536220: W c:\\tf_jenkins\\home\\workspace\\release-win\\m\\windows\\py\\35\\tensorflow\\core\\platform\\cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX instructions, but these are available on your machine and could speed up CPU computations.\n",
      "2017-10-18 11:59:43.536225: W c:\\tf_jenkins\\home\\workspace\\release-win\\m\\windows\\py\\35\\tensorflow\\core\\platform\\cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX2 instructions, but these are available on your machine and could speed up CPU computations.\n",
      "2017-10-18 11:59:43.536231: W c:\\tf_jenkins\\home\\workspace\\release-win\\m\\windows\\py\\35\\tensorflow\\core\\platform\\cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use FMA instructions, but these are available on your machine and could speed up CPU computations.\n",
      "[2017-10-18 11:59:43,669] Restoring parameters from ../../dqn/test16/model.ckpt\r\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "PyObject <dqn.nn_wrapper.NNWrapper object at 0x000000002A8A7B70>"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "nn = nn_wrapper[:NNWrapper](\"../../dqn/test16/\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average reward: 0.94; Average # of steps: 27.68; # of crashes: 3; # of time out 0; # of successes 97\n"
     ]
    }
   ],
   "source": [
    "r_avg = 0\n",
    "step_avg = 0\n",
    "crash = 0\n",
    "time_out = 0\n",
    "success = 0\n",
    "n_ep = 100\n",
    "max_steps = 100\n",
    "dist = initialstate_distribution(pomdp)\n",
    "saved_states = OCState[]\n",
    "\n",
    "for i=1:n_ep    \n",
    "    disc = 1.0\n",
    "    r_total = 0.0    \n",
    "    s = rand(rng, dist)\n",
    "    a_map = actions(pomdp)\n",
    "    o_init = generate_o(pomdp, s, rng)\n",
    "    o = preprocess_o(o_init, pomdp)\n",
    "    step = 1\n",
    "\n",
    "    while !isterminal(pomdp, s) && step <= max_steps # TODO also check for terminal observation\n",
    "        a = a_map[nn[:action](o)+1]\n",
    "#         println(o, \"\\n\\n\")\n",
    "        sp, o_, r = generate_sor(pomdp, s, a, rng)\n",
    "#         println(o_)\n",
    "        push!(saved_states, sp)\n",
    "        r_total += disc*r\n",
    "\n",
    "        s = sp\n",
    "        op = preprocess_o(o_, pomdp, o)\n",
    "#         println(op)\n",
    "        \n",
    "        o = op\n",
    "\n",
    "#         disc *= discount(pomdp)\n",
    "        step += 1\n",
    "    end\n",
    "    r_avg += r_total\n",
    "    if r_total <= -1\n",
    "        crash += 1\n",
    "    elseif step >= 100\n",
    "        time_out += 1\n",
    "    else\n",
    "            success += 1\n",
    "    end\n",
    "    step_avg += step\n",
    "#     println(r_total)\n",
    "end\n",
    "r_avg /= n_ep\n",
    "step_avg /= n_ep\n",
    "println(\"Average reward: $r_avg; Average # of steps: $step_avg; # of crashes: $crash; # of time out $time_out; # of successes $success\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "## Visualize policy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<video autoplay controls><source src=\"files/reel-17147174824786228528.webm?18028092960200373390\" type=\"video/webm\"></video>"
      ],
      "text/plain": [
       "Reel.Frames{MIME{Symbol(\"image/png\")}}(\"C:\\\\Users\\\\Maxime\\\\AppData\\\\Local\\\\Temp\\\\jl_7BA.tmp\", 0x0000000000000a6c, 2.0, nothing)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "duration, fps, render_states = animate_states(pomdp, saved_states)\n",
    "speed_factor = 1\n",
    "film = roll(render_states, fps = fps*speed_factor, duration = duration/speed_factor)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Julia 0.6.0",
   "language": "julia",
   "name": "julia-0.6"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "0.6.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
